/*
 * Copyright © 2023 Bas Nieuwenhuizen
 *
 * SPDX-License-Identifier: MIT
 */

#include "nir_builder.h"
#include "radv_nir.h"

/* This pass lowers cooperative matrix.
 *
 * On GFX11, the A&B matrices need to be replicated: lanes 0..15 are replicated
 * to 16..31 and for wave64 also into lanes 32..47 and 48..63. A&B matrices are
 * always vectors of 16 elements.
 *
 * On GFX12, there is no data replication and the matrix layouts are described
 * below:
 *
 * Wave32:
 * A&B:
 *         0..15  | 16..31 (lanes)
 * v0 lo:  row 0  | row 4
 * v0 hi:  row 1  | row 5
 * v1 lo:  row 2  | row 6
 * v1 hi:  row 3  | row 7
 * v2 lo:  row 8  | row 12
 * v2 hi:  row 9  | row 13
 * v3 lo:  row 10 | row 14
 * v3 hi:  row 11 | row 15
 *
 * C:
 *         0..15  | 16..31 (lanes)
 * v0 lo:  row 0  | row 8
 * v0 hi:  row 1  | row 9
 * v1 lo:  row 2  | row 10
 * v1 hi:  row 3  | row 11
 * v2 lo:  row 4  | row 12
 * v2 hi:  row 5  | row 13
 * v3 lo:  row 6  | row 14
 * v3 hi:  row 7  | row 15
 *
 * Wave64:
 * A&B:
 *         0..15 | 16..31 | 32..47 | 48..63 (lanes)
 * v0 lo:  row 0 | row 4  | row 8  | row 12
 * v0 hi:  row 1 | row 5  | row 9  | row 13
 * v1 lo:  row 2 | row 6  | row 10 | row 14
 * v1 hi:  row 3 | row 7  | row 11 | row 15
 *
 * C:
 *         0..15 | 16..31 | 32..47 | 48..63 (lanes)
 * v0 lo:  row 0 | row 8  | row 4  | row 12
 * v0 hi:  row 1 | row 9  | row 5  | row 13
 * v1 lo:  row 2 | row 10 | row 6  | row 14
 * v1 hi:  row 3 | row 11 | row 7  | row 15
 */

/* Target parameters that select the per-invocation matrix layout used by the
 * helpers and the lowering in this file.
 */
typedef struct {
   /* Hardware generation; the code in this file branches on GFX12 and later
    * versus earlier generations.
    */
   enum amd_gfx_level gfx_level;
   /* Wave size in lanes; the code in this file distinguishes 32 and 64. */
   unsigned wave_size;
} lower_cmat_params;

/* Returns the number of vector components each invocation uses to hold its
 * part of a cooperative matrix described by `desc`.
 *
 * - GFX12+: only 16x16 matrices are expected (asserted); the 256 elements are
 *   split evenly across the wave.
 * - Older generations, A/B matrices: always 16 components.
 * - Older generations, accumulators: one 32-bit slot per element owned by the
 *   invocation, expressed in units of the element type, so types narrower
 *   than 32 bits get proportionally more components.
 */
static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
{
   if (params->gfx_level >= GFX12) {
      assert(desc.cols == 16 && desc.rows == 16);
      return 256 / params->wave_size;
   }

   if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
      return 16;

   /* Elements owned by one invocation, then scaled to 32 bits per element. */
   const unsigned elems_per_lane = desc.cols * desc.rows / params->wave_size;
   return elems_per_lane * 32 / glsl_base_type_bit_size(desc.element_type);
}

/* Returns the distance, in vector components, between two consecutive real
 * elements of the lowered matrix vector.
 *
 * Before GFX12, accumulator (C) matrices occupy one VGPR per element even
 * when the element type is narrower than 32 bits, so e.g. 8 fp16 elements are
 * represented as an f16vec16 in which only every second component is used.
 * Dividing radv_nir_cmat_length() by this factor gives the real element
 * count. Every other case has no padding and yields 1.
 */
static unsigned
radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_params *params)
{
   if (params->gfx_level >= GFX12 || desc.use != GLSL_CMAT_USE_ACCUMULATOR)
      return 1;

   return 32 / glsl_base_type_bit_size(desc.element_type);
}

/* Returns the bit size of the element type of the matrix described by `desc`. */
static unsigned
radv_nir_cmat_bits(struct glsl_cmat_description desc)
{
   const unsigned element_bits = glsl_base_type_bit_size(desc.element_type);

   return element_bits;
}

/* Emits a load of the lowered vector behind a cooperative matrix deref.
 *
 * `src` must be the SSA value of a deref instruction whose type is still a
 * cooperative matrix type: the vector width and bit size of the load are
 * derived from that type's description.
 */
static nir_def *
radv_nir_load_cmat(nir_builder *b, const lower_cmat_params *params, nir_def *src)
{
   nir_deref_instr *mat_deref = nir_instr_as_deref(src->parent_instr);
   const struct glsl_cmat_description *mat_desc = glsl_get_cmat_description(mat_deref->type);

   const unsigned num_components = radv_nir_cmat_length(*mat_desc, params);
   const unsigned bit_size = glsl_base_type_bit_size(mat_desc->element_type);

   return nir_build_load_deref(b, num_components, bit_size, src, 0);
}

/* Returns the type that replaces `orig_type` once cooperative matrices have
 * been lowered to vectors.
 *
 * - A cooperative matrix type becomes a vector of its element type with
 *   radv_nir_cmat_length() components.
 * - Arrays and structs are rebuilt only if something they contain changes;
 *   otherwise `orig_type` itself is returned, so callers can detect a change
 *   by comparing pointers.
 * - Any other type is returned unchanged.
 *
 * `type_map` is consulted first for every type; this function itself only
 * inserts rebuilt struct types into it.
 */
static const struct glsl_type *
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map,
                               const lower_cmat_params *params)
{
   struct hash_entry *cached = _mesa_hash_table_search(type_map, orig_type);
   if (cached)
      return cached->data;

   if (glsl_type_is_cmat(orig_type)) {
      struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
      unsigned num_components = radv_nir_cmat_length(desc, params);

      return glsl_vector_type(desc.element_type, num_components);
   }

   if (glsl_type_is_array(orig_type)) {
      const struct glsl_type *old_elem = glsl_get_array_element(orig_type);
      const struct glsl_type *new_elem = radv_nir_translate_matrix_type(old_elem, type_map, params);

      if (new_elem == old_elem)
         return orig_type;

      return glsl_array_type(new_elem, glsl_get_length(orig_type), glsl_get_explicit_stride(orig_type));
   }

   if (!glsl_type_is_struct(orig_type))
      return orig_type;

   const unsigned field_count = glsl_get_length(orig_type);

   /* First pass: stop at the first field whose type would change. */
   bool needs_rewrite = false;
   for (unsigned f = 0; f < field_count && !needs_rewrite; ++f) {
      const struct glsl_type *old_field = glsl_get_struct_field(orig_type, f);

      needs_rewrite = radv_nir_translate_matrix_type(old_field, type_map, params) != old_field;
   }

   if (!needs_rewrite)
      return orig_type;

   /* Second pass: copy every field descriptor and patch in the translated type. */
   struct glsl_struct_field *new_fields = malloc(field_count * sizeof(*new_fields));

   for (unsigned f = 0; f < field_count; ++f) {
      new_fields[f] = *glsl_get_struct_field_data(orig_type, f);
      new_fields[f].type = radv_nir_translate_matrix_type(new_fields[f].type, type_map, params);
   }

   const struct glsl_type *translated = glsl_struct_type(new_fields, field_count, glsl_get_type_name(orig_type),
                                                         glsl_struct_type_is_packed(orig_type));
   free(new_fields);

   /* Remember the result so later lookups of this struct reuse it. */
   _mesa_hash_table_insert(type_map, orig_type, (void *)translated);
   return translated;
}

/* Emits code computing the first matrix row handled by the invocation whose
 * subgroup invocation index is `local_idx`. Callers add a per-component row
 * offset on top of the returned value.
 *
 * Before GFX12, A/B matrices always start at row 0 and accumulators start at
 * the index of the invocation's group of 16 lanes. On GFX12+ the group index
 * is mapped to a starting row according to the layout tables in the comment
 * at the top of this file.
 */
static nir_def *
radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower_cmat_params *params,
                  nir_def *local_idx)
{
   const bool is_accum = desc.use == GLSL_CMAT_USE_ACCUMULATOR;

   if (params->gfx_level < GFX12) {
      if (!is_accum)
         return nir_imm_int(b, 0);

      return nir_udiv_imm(b, local_idx, 16);
   }

   /* Index of the 16-lane group this invocation belongs to. */
   nir_def *lane_group = nir_udiv_imm(b, local_idx, 16);

   if (is_accum && params->wave_size == 64) {
      /* The wave64 accumulator layout swaps the rows of lanes 16..31 and
       * 32..47, i.e. the two bits of the group index are reversed. Shifting
       * the reversed value right by 28 instead of 30 leaves it multiplied
       * by 4.
       */
      nir_def *reversed = nir_bitfield_reverse(b, lane_group);

      return nir_ushr_imm(b, reversed, 30 - 2);
   }

   /* Wave32 accumulators cover 8 rows per lane group, everything else 4. */
   const unsigned rows_per_group = (is_accum && params->wave_size == 32) ? 8 : 4;

   return nir_imul_imm(b, lane_group, rows_per_group);
}

/* Lowers cooperative matrices in `shader` to plain per-invocation vectors.
 *
 * Variable and deref types containing cooperative matrices are replaced by
 * their vector equivalents (see radv_nir_translate_matrix_type), and every
 * cmat_* intrinsic handled below is replaced by loads/stores of those vectors
 * plus regular NIR instructions.
 *
 * `gfx_level` and `wave_size` select the per-invocation matrix layout.
 *
 * Returns false without touching the shader when
 * shader->info.cs.has_cooperative_matrix is not set; otherwise returns the
 * result of nir_progress() for the processed function.
 *
 * NOTE(review): only the first function in shader->functions is processed;
 * presumably all functions have been inlined before this pass runs — confirm.
 */
bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
{
   bool progress = false;

   if (!shader->info.cs.has_cooperative_matrix)
      return false;

   const lower_cmat_params params = {
      .gfx_level = gfx_level,
      .wave_size = wave_size,
   };

   struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
   /* Maps original types to their translated counterparts (see
    * radv_nir_translate_matrix_type); destroyed at the end of the pass.
    */
   struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);

   /* Retype shader-temp variables that contain cooperative matrices. */
   nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, &params);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   /* Same for function-local variables. */
   nir_foreach_function_temp_variable (var, func->impl) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, &params);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_builder b = nir_builder_create(func->impl);

   /* Iterate in reverse order so that lowering can still use the matrix types from the derefs before we change it. */
   nir_foreach_block_reverse (block, func->impl) {
      nir_foreach_instr_reverse_safe (instr, block) {
         /* All replacement code is emitted right before the instruction being lowered. */
         b.cursor = nir_before_instr(instr);

         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
            case nir_intrinsic_cmat_length: {
               /* The visible element count is the vector length without the
                * padding components accounted for by radv_nir_cmat_length_mul.
                */
               struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
               unsigned len = radv_nir_cmat_length(desc, &params) / radv_nir_cmat_length_mul(desc, &params);
               nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_extract: {
               /* src[0] is the matrix deref, src[1] the element index. */
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src0 = radv_nir_load_cmat(&b, &params, intr->src[0].ssa);

               /* Convert the element index into a component index of the lowered vector. */
               nir_def *index = intr->src[1].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, &params));

               nir_def *elem = nir_vector_extract(&b, src0, index);

               nir_def_rewrite_uses(&intr->def, elem);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_insert: {
               /* src[0] is the destination deref, src[1] the new element,
                * src[2] the source matrix deref and src[3] the element index.
                */
               nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               /* Convert the element index into a component index of the lowered vector. */
               nir_def *index = intr->src[3].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, &params));

               nir_def *elem = intr->src[1].ssa;
               nir_def *r = nir_vector_insert(&b, src1, elem, index);
               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_construct: {
               /* Fill every component of the destination vector with the scalar in src[1]. */
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *elem = intr->src[1].ssa;

               nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, &params));

               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_load: {
               /* src[0] is the destination matrix deref, src[1] the pointer
                * deref to load from and src[2] the stride. Each invocation
                * loads the elements it owns one scalar at a time.
                */
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               nir_def *stride = intr->src[2].ssa;

               nir_def *local_idx = nir_load_subgroup_invocation(&b);
               /* Position of the invocation within its group of 16 lanes. */
               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* A input is transposed */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, &params);
               unsigned mul = radv_nir_cmat_length_mul(desc, &params);
               /* Only used before GFX12: accumulators are spread over the whole
                * wave, A/B matrices over 16 lanes.
                */
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
               /* NOTE(review): `vars` holds 16 entries, but radv_nir_cmat_length
                * can compute more than 16 for pre-GFX12 accumulators whose
                * element type is narrower than 16 bits — confirm such matrix
                * types never reach this pass.
                */
               nir_def *vars[16];
               /* Padding components (those that are not a multiple of `mul`) are left undefined. */
               if (mul > 1) {
                  for (unsigned i = 0; i < length; ++i)
                     if (i % mul != 0)
                        vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
               }

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset;
                  uint32_t row_iter;

                  if (gfx_level >= GFX12) {
                     /* Wave32 A/B matrices skip four rows after the first four
                      * elements (rows 0..3 then 8..11 relative to base_row);
                      * all other GFX12 layouts use consecutive rows.
                      */
                     row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
                  } else {
                     /* One row per element for A/B, wave_size / 16 rows per element for accumulators. */
                     row_iter = i * lanes_per_iter / 16;
                  }

                  row_offset = nir_iadd_imm(&b, base_row, row_iter);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  /* Only the (possibly swapped) column offset is scaled by the stride. */
                  col_offset = nir_imul(&b, col_offset, stride);

                  /* Array indices must have the same bit size as the pointer deref. */
                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  /* Offset the incoming pointer by col_offset, reinterpret the
                   * result as a pointer to scalar elements, then offset that by
                   * row_offset elements.
                   */
                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  vars[i * mul] = nir_load_deref(&b, iter_deref);
               }

               nir_def *mat = nir_vec(&b, vars, length);
               nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_store: {
               /* src[0] is the pointer deref to store to, src[1] the matrix
                * deref and src[2] the stride. Addressing mirrors the
                * cmat_load lowering.
                */
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_def *src = intr->src[1].ssa;
               nir_def *stride = intr->src[2].ssa;

               nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               src = radv_nir_load_cmat(&b, &params, src);

               nir_def *local_idx = nir_load_subgroup_invocation(&b);

               /* Before GFX12, A/B matrices are only written by lanes 0..15. */
               if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));

               /* Position of the invocation within its group of 16 lanes. */
               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* A input is transposed */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, &params);
               unsigned mul = radv_nir_cmat_length_mul(desc, &params);
               /* Only used before GFX12: accumulators are spread over the whole
                * wave, A/B matrices over 16 lanes.
                */
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
               /* NOTE(review): `vars` holds 16 entries, but radv_nir_cmat_length
                * can compute more than 16 for pre-GFX12 accumulators whose
                * element type is narrower than 16 bits — confirm such matrix
                * types never reach this pass.
                */
               nir_def *vars[16];
               for (unsigned i = 0; i < length; ++i)
                  vars[i] = nir_channel(&b, src, i);

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset;
                  uint32_t row_iter;

                  if (gfx_level >= GFX12) {
                     /* Wave32 A/B matrices skip four rows after the first four
                      * elements (rows 0..3 then 8..11 relative to base_row);
                      * all other GFX12 layouts use consecutive rows.
                      */
                     row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
                  } else {
                     /* One row per element for A/B, wave_size / 16 rows per element for accumulators. */
                     row_iter = i * lanes_per_iter / 16;
                  }

                  row_offset = nir_iadd_imm(&b, base_row, row_iter);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  /* Only the (possibly swapped) column offset is scaled by the stride. */
                  col_offset = nir_imul(&b, col_offset, stride);

                  /* Array indices must have the same bit size as the pointer deref. */
                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  /* Offset the incoming pointer by col_offset, reinterpret the
                   * result as a pointer to scalar elements, then offset that by
                   * row_offset elements.
                   */
                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  /* Padding components (indices that are not a multiple of `mul`) are not stored. */
                  nir_store_deref(&b, iter_deref, vars[i * mul], 1);
               }

               /* Close the lane 0..15 conditional opened above. */
               if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_pop_if(&b, NULL);

               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_muladd: {
               /* src[0] is the result deref, src[1..3] are the A, B and C
                * matrix derefs; the multiply-add itself is left to the
                * cmat_muladd_amd intrinsic operating on the lowered vectors.
                */
               nir_def *A = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
               nir_def *B = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
               nir_def *C = radv_nir_load_cmat(&b, &params, intr->src[3].ssa);
               nir_def *ret;

               ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
                                         .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_unary_op: {
               /* Applies the ALU op component-wise. Before GFX12, accumulators
                * with 16-bit elements only use every second component (see
                * radv_nir_cmat_length_mul), so the vector has to be compacted
                * or expanded when the op changes the element size between 16
                * and 32 bits.
                */
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);

               /* 16-bit -> 32-bit accumulator: drop the padding components before the op. */
               if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 16 &&
                   glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i * 2 < src->num_components; ++i) {
                     components[i] = nir_channel(&b, src, i * 2);
                  }
                  src = nir_vec(&b, components, src->num_components / 2);
               }

               nir_def *ret = nir_build_alu1(&b, op, src);

               /* 32-bit -> 16-bit accumulator: interleave undefined padding components after the op. */
               if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 32 &&
                   glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i < ret->num_components; ++i) {
                     components[i * 2] = nir_channel(&b, ret, i);
                     components[i * 2 + 1] = nir_undef(&b, 1, 16);
                  }
                  ret = nir_vec(&b, components, ret->num_components * 2);
               }

               nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_scalar_op: {
               /* Binary ALU op between the lowered matrix (src[1]) and the scalar in src[2]. */
               nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_binary_op: {
               /* Binary ALU op between the two lowered matrices in src[1] and src[2]. */
               nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
               nir_def *src2 = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, src2);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_bitcast: {
               /* The lowered source vector is stored to the destination unchanged. */
               nir_def *src1 = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
                               nir_component_mask(src1->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_copy: {
               /* Becomes a regular deref copy from src[1] to src[0]. */
               nir_build_copy_deref(&b, intr->src[0].ssa, intr->src[1].ssa);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            default:
               /* Not a cooperative matrix intrinsic: move on to the next instruction. */
               continue;
            }
            break;
         }
         case nir_instr_type_deref: {
            /* Because instructions are visited in reverse, the intrinsics
             * using this deref have already been lowered by the time its
             * type is replaced here.
             */
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, &params);
            if (new_type != deref->type) {
               deref->type = new_type;
               progress = true;
            }
            break;
         }
         default:
            continue;
         }
      }
   }

   _mesa_hash_table_destroy(type_map, NULL);

   return nir_progress(progress, func->impl, 0);
}
